from typing import Optional

import numpy as np
import torch

from dada.model.torch_model import TorchModel


class PolynomialFeasibleModel(TorchModel):
    """Polyhedral feasibility problem with a polynomial penalty objective.

    The objective is

        f(x) = (1 / m) * sum_i max(0, <a_i, x> - b_i) ** q,

    the mean of the q-th powers of the violations of the linear constraints
    ``<a_i, x> <= b_i`` (row ``a_i`` of ``a_matrix``, entry ``b_i`` of
    ``b_matrix``). ``f`` is zero exactly when every constraint holds.
    """

    def __init__(self,
                 n_features: int,
                 q: float,
                 num_polyhedron: int,
                 a_matrix: torch.Tensor,
                 b_matrix: torch.Tensor,
                 init_point: Optional[torch.Tensor] = None):
        """
        Args:
            n_features: dimension of the decision variable; forwarded to
                ``TorchModel``.
            q: exponent applied to each constraint violation.
            num_polyhedron: number of constraints m. It is stored on the
                instance; the methods of this class use the rows of
                ``a_matrix`` rather than this value.
            a_matrix: constraint normals, shape (m, n_features) as produced
                by ``generate_function_variables``.
            b_matrix: constraint offsets, shape (m,) as produced by
                ``generate_function_variables``.
            init_point: optional starting point forwarded to ``TorchModel``.
        """
        # Assigned before super().__init__ so they already exist if the base
        # constructor touches them -- NOTE(review): confirm against TorchModel.
        self.q = q
        self.num_polyhedron = num_polyhedron
        self.a_matrix = a_matrix
        self.b_matrix = b_matrix
        super().__init__(n_features, init_point)

    def loss(self):
        """Return the differentiable objective ``f(self.x)`` as a torch scalar.

        ``self.x`` is presumably the optimisation variable created by
        ``TorchModel.__init__`` -- TODO confirm.
        """
        inner_products = torch.matmul(self.a_matrix, self.x)

        # [<a_i, x> - b_i]_+ : only violated constraints contribute.
        residuals = torch.relu(inner_products - self.b_matrix)

        # Average of the q-th powers of the violations.
        return torch.mean(torch.pow(residuals, self.q))

    def compute_value(self, point: np.ndarray):
        """Evaluate the same objective as ``loss`` with NumPy at ``point``.

        Args:
            point: candidate solution, length n_features.

        Returns:
            The mean of ``max(0, A @ point - b) ** q`` as a NumPy scalar.
        """
        # .cpu() lets this work when the tensors live on a GPU; it is a
        # no-op for tensors already on the CPU.
        a_matrix = self.a_matrix.detach().cpu().numpy()
        b_matrix = self.b_matrix.detach().cpu().numpy()

        residuals = np.maximum(0, (a_matrix @ np.asarray(point)) - b_matrix)
        # np.power rather than np.pow: np.pow is only an alias added in
        # NumPy 2.0 and raises AttributeError on earlier versions.
        return np.mean(np.power(residuals, self.q))

    @staticmethod
    def generate_function_variables(n_features, num_polyhedron, optimal_point):
        """Randomly generate constraints ``(A, b)`` satisfied by ``optimal_point``.

        Each row ``a_i`` is drawn uniformly from [-1, 1]^n_features and negated
        if needed so that ``<a_i, x*> <= 0``. Then ``b_i = <a_i, x*> + s_i``
        with slack ``s_i ~ U[0, -0.1 * min_j <a_j, x*>]``. Since every
        ``s_i >= 0``, ``x* = optimal_point`` satisfies all constraints, so the
        objective is zero there.

        Args:
            n_features: dimension of each constraint row.
            num_polyhedron: number of constraints m (must be >= 1).
            optimal_point: point that must be feasible, length n_features.
                It is flattened to a 1-D vector.

        Returns:
            Tuple ``(a_matrix, b_matrix)`` of float64 tensors with shapes
            (m, n_features) and (m,).

        Raises:
            ValueError: if ``num_polyhedron`` < 1. The slack bound needs at
                least one constraint.
        """
        if num_polyhedron < 1:
            raise ValueError("num_polyhedron must be at least 1")

        optimal_point = np.ravel(np.asarray(optimal_point))

        # One (m, n) draw consumes the global RNG stream in the same order as
        # drawing the rows one at a time, so seeded runs reproduce exactly.
        a_matrix = np.random.uniform(low=-1, high=1, size=(num_polyhedron, n_features))

        # Flip every row with <a_i, x*> > 0 so that <a_i, x*> <= 0 for all i.
        a_matrix[a_matrix @ optimal_point > 0] *= -1

        inner_products = a_matrix @ optimal_point
        c_min = np.min(inner_products)  # <= 0, so the slack bound is >= 0

        slacks = np.random.uniform(low=0, high=(-0.1) * c_min, size=num_polyhedron)
        b_matrix = inner_products + slacks

        return torch.from_numpy(a_matrix).double(), torch.from_numpy(b_matrix).double()
